// file: system/mm/memcache.c
// author: jiangxinpeng
// time: 2021.2.4
// copyright: (C) 2020-2050 by Jiang xinpeng. All rights reserved.

#include <arch/page.h>
#include <arch/interrupt.h>
#include <arch/mem_pool.h>
#include <arch/pymem.h>
#include <os/mutexlock.h>
#include <os/debug.h>
#include <os/driver.h>
#include <lib/bitmap.h>
#include <lib/math.h>
#include <lib/string.h>
#include <lib/math.h>
#include <lib/assert.h>

// size-class table: one entry per mem cache the allocator supports.
// mem_cache is filled in by MemCacheMake(); the {0, NULL} entry terminates.
PRIVATE mem_cache_size_t cache_size[] =
    {
#if PAGE_SIZE == 0x1000
        {32, NULL}, // 32B class only exists with 4KB pages
#endif
        {64, NULL},
        {128, NULL},
        {256, NULL},
        {512, NULL},
        // entries above are the sub-1KB classes (stored inside single pages)
        {1024, NULL},   // 1KB
        {2048, NULL},   // 2KB
        {4096, NULL},   // 4KB
        {8192, NULL},   // 8KB
        {16384, NULL},  // 16KB
        {32768, NULL},  // 32KB
        {65536, NULL},  // 64KB
        {131072, NULL}, // 128KB
// in the general configuration, 128KB is the largest cached allocation size
#ifdef CONFIG_LARGE_MEM
        {262144, NULL},  // 256KB
        {524288, NULL},  // 512KB
        {1048576, NULL}, // 1MB
        {2097152, NULL}, // 2MB
        {4194304, NULL}, // 4MB
#endif
        {0, NULL} // end-of-table sentinel
};

// MEMCACHE_NUM_MAX must equal the number of non-sentinel cache_size entries:
// the optional 32B class (4KB pages only) and the CONFIG_LARGE_MEM classes
// change the count
#if PAGE_SIZE == 4096
#ifdef CONFIG_LARGE_MEM
#define MEMCACHE_NUM_MAX 18
#else
#define MEMCACHE_NUM_MAX 13
#endif
#else
#ifdef CONFIG_LARGE_MEM
#define MEMCACHE_NUM_MAX 17
#else
#define MEMCACHE_NUM_MAX 12
#endif
#endif

// system-wide mem cache table, one slot per size class
mem_cache_t mem_cache_table[MEMCACHE_NUM_MAX];

// list of allocations larger than MEM_CACHE_SIZE_MAX (see KMemAlloc/KMemFree)
list_t large_object_list;
// lock intended to protect large_object_list
// NOTE(review): never actually taken by KMemAlloc/KMemFree — confirm intent
DEFINE_MUTEX_LOCK(lock);

// page most recently allocated for a group header; written by MemGroupCreate
// and read by MemGroupInit
static uint8_t *kernel_pages;

/**
 * MemCacheInit - initialize one mem cache for a given object size class
 * @cache: cache structure to initialize
 * @name:  human-readable cache name (must fit in MEM_CACHE_NAME_LEN)
 * @size:  object size in bytes served by this cache
 * @flag:  cache flags, stored as-is
 *
 * Returns 0 on success, -1 when @size is 0.
 */
int MemCacheInit(mem_cache_t *cache, char *name, size_t size, flags_t flag)
{
    if (!size)
        return -1;

    // init cache group lists
    list_init(&cache->full_groups);
    list_init(&cache->free_groups);
    list_init(&cache->partial_groups);
    // according to cache size, set cache object count
    if (size < 1024) // small sizes live inside a single page to save memory
    {
        // group header size aligned to 8 bytes
        // (removed the outer `uint64_t group_size;` that this declaration shadowed)
        uint64_t group_size = ALIGH_WITH(sizeof(mem_group_t), 8);
        // page space left for objects: page minus header minus bitmap reserve
        uint64_t left_size = PAGE_SIZE - group_size - MEM_CACHE_PAGERESERVED_SIZE;

        cache->object_count = do_div64(left_size, size);
    }
    else
    {
        if (size < 128 * 1024) // objects up to 128KB share a 1MB group
        {
            cache->object_count = do_div64((1 * MB), size);
        }
        else
        {
            if (size < 4 * 1024 * 1024) // objects up to 4MB share a 4MB group
            {
                cache->object_count = do_div64((4 * MB), size);
            }
            else
            {
                // above max range, cache 4 objects per group
                cache->object_count = 4;
            }
        }
    }
    // single object size
    cache->object_size = size;
    cache->flags = flag;
    // clear then copy the cache name
    memset(cache->name, 0, MEM_CACHE_NAME_LEN);
    // NOTE(review): unbounded copy — caller must guarantee
    // strlen(name) < MEM_CACHE_NAME_LEN
    strcpy(cache->name, name);

    return 0;
}

// alloc page
void *MemCacheAllocPage(uint32_t count)
{
    void *page = AllocKernelPage(count);
    if (page != NULL)
    {
        // link pypage to vbase
        VbaseLinkPybase((uint32_t)KERNEL_PYBASE2VBASE(page), (uint32_t)page, KERNEL_PAGE_ATTR);
        // KPrint("[mem] alloc page count %d pybase %x vbase %x\n", count, page, KERNEL_PYBASE2VBASE(page));
        return KERNEL_PYBASE2VBASE(page);
    }
    KPrint("[mem] alloc page count %d failed!\n", count);
    return NULL;
}

// free page
void MemCacheFreePage(uint32_t page, uint32_t size)
{
    void *pypage = KERNEL_VBASE2PYBASE(page);
    uint32_t pages = DIV_ROUND_DOWN(size, PAGE_SIZE);

    if (pypage != NULL)
    {
        while (pages > 0)
        {
            // unlink vbase and pypage
            VbaseUnlinkPybase(page);
            FreeMemNode(ADDR(pypage));
            // next pages
            pypage += PAGE_SIZE;
            pages--;
        }
    }
}

// create a new group for `cache`: one kernel page holds the group header
// (and, for small size classes, the bitmap and objects as well).
// Returns 0 on success, -1 on failure.
int MemGroupCreate(mem_cache_t *cache, flags_t flags)
{
    // allocate one page; MemGroupInit reads the global kernel_pages cursor,
    // so it must be updated before calling it
    kernel_pages = MemCacheAllocPage(1);
    mem_group_t *group = (mem_group_t *)kernel_pages;

    if (group == NULL)
        return -1;

    if (MemGroupInit(cache, group, flags) < 0)
    {
        KPrint("[mem] cache size %d init group failed!\n", cache->object_size);
        // release the page that held the group header
        MemCacheFreePage((address_t)group, PAGE_SIZE);
        return -1;
    }
    return 0;
}

/**
 * MemGroupInit - initialize a freshly allocated group for `cache`
 *
 * Lays out the object bitmap (and, for small size classes, the objects)
 * inside the page pointed to by the global kernel_pages, which
 * MemGroupCreate just allocated. Returns 0 on success, -1 when the
 * dedicated object pages could not be allocated.
 */
int MemGroupInit(mem_cache_t *cache, mem_group_t *group, flags_t flags)
{
    mem_node_t *node;
    uint32_t pages;

    // a new group starts on the cache's free list; MemCacheAllocObject moves
    // it to the partial list when it begins serving allocations
    list_add_after(&group->list, &cache->free_groups);
    // the bitmap is placed right after the group header in the same page.
    // NOTE(review): relies on the kernel_pages global set by MemGroupCreate,
    // which makes group creation non-reentrant — confirm callers serialize.
    uint8_t *map = kernel_pages + sizeof(mem_group_t);
    // one bit per object, rounded up to whole bytes
    group->map.length = DIV_ROUND_UP(cache->object_count, 8);
    // map data area
    group->map.bits = map;
    // init bitmaps
    init_bitmap(&group->map);

    // small objects: header, bitmap and objects all share this single page
    // (the layout MemCacheInit sized object_count for)
    if (cache->object_size < 1024)
    {
        // single page
        pages = 1;
        // objects start after the reserved bitmap area
        // NOTE(review): MemCacheInit aligned the header size to 8 bytes but
        // no alignment is applied to `map` here; the reserved area absorbs
        // the difference — confirm
        group->object = map + MEM_CACHE_PAGERESERVED_SIZE;

        // record owning cache/group in the page's mem node so KMemFree can
        // route objects back to this group
        node = Pybase2Memnode(KERNEL_VBASE2PYBASE(group));
        node->mem_cache = cache;
        node->mem_group = group;
    }
    else
    {
        // large objects: allocate a dedicated run of pages for the objects
        pages = (cache->object_count * cache->object_size) / PAGE_SIZE;
        group->object = MemCacheAllocPage(pages);
        if (group->object == NULL)
        {
            KPrint("[mem] %s: group alloc mem for object failed!\n", __func__);
            return -1;
        }

        // tag every object page's mem node with its owning cache and group
        for (int i = 0; i < pages; i++)
        {
            node = Pybase2Memnode(ADDR(KERNEL_VBASE2PYBASE(group->object + i * PAGE_SIZE)));
            node->mem_group = group;
            node->mem_cache = cache;
        }
    }

    // all objects start out free
    group->using_count = 0;
    group->free_count = cache->object_count;
    group->flags = flags;

    return 0;
}

static int MemCacheMake()
{
    mem_cache_size_t *cachesz = cache_size;
    mem_cache_t *cache = &mem_cache_table[0];

    while (cachesz->cache_size != 0)
    {
        if (MemCacheInit(cache, "mem_cache", cachesz->cache_size, 0) != 0)
        {
            return -1;
        }
        // add cache to cache size point area
        cachesz->mem_cache = cache;
        // init next cache size
        cachesz++;
        // point next cache
        cache++;

        // KPrint("[mem] make cache size %d\n", cachesz->cache_size);
    }
    return 0;
}

// alloc a free object on group
void *MemCacheAlloc(mem_cache_t *cache, mem_group_t *group)
{
    intptr_t object = 0;
    int idx = 0;
    list_t *node = NULL;

    // alloc free bit
    idx = bitmap_find_free(&group->map);
    if (idx != -1)
    {
        // KPrint("[mem] alloc free idx %d\n", idx);
        //  set bits used
        bitmap_set(&group->map, 1, idx);
        // set group status
        group->free_count--;
        group->using_count++;
        // point to targe object
        object = group->object + idx * cache->object_size;
        // KPrint("[mem] group %x object base %x object size %d \n", group, group->object, cache->object_size);
        //  check group alloc status
        if (group->free_count <= 0)
        {
            list_del_init(&group->list);
            // add to cache full list
            list_add_tail(&group->list, &cache->full_groups);
        }
        // KPrint("[mem] %s: alloc mem at %x\n", __func__, object);
        return object;
    }

    assert(list_find(&group->list, &cache->partial_groups));
    assert(!list_find(&group->list, &cache->free_groups));
    assert(!list_find(&group->list, &cache->full_groups));
    KPrint("[mem] %s: alloc mem failed!\n", __func__);
    return NULL;
}

// allocate one object from `cache`, creating a new group when no partial or
// free group is available; returns the object address or NULL on failure
void *MemCacheAllocObject(mem_cache_t *cache)
{
    mem_group_t *group = NULL;
    list_t *node = NULL;
    uint8_t *object = NULL;

    // no partially-used group: pull one from the free list
    if (list_empty(&cache->partial_groups))
    {
        if (list_empty(&cache->free_groups))
        {
            // no free group either: create a brand-new one
            if (MemGroupCreate(cache, 0) < 0)
            {
                KPrint("[mem] create group for cache size %d failed!\n", cache->object_size);
                return NULL;
            }
        }
        // move the first free group onto the partial list
        node = cache->free_groups.next;
        list_del_init(node);
        list_add_tail(node, &cache->partial_groups);
        // NOTE: the free list may legally hold MORE groups at this point
        // (MemCacheFreeObject parks every fully-freed group there), so the
        // old `assert(list_empty(&cache->free_groups))` was wrong and could
        // fire on a perfectly valid state — removed.
    }
    // grab the first partial group and allocate from its bitmap
    group = list_first_owner(&cache->partial_groups, mem_group_t, list);
    object = MemCacheAlloc(cache, group);
    return object;
}

// return `object` to its group inside `cache`, updating the group's bitmap,
// counters, and list membership
void MemCacheFreeObject(mem_cache_t *cache, void *object)
{
    int index;
    mem_node_t *node = Pybase2Memnode(ADDR(KERNEL_VBASE2PYBASE(object)));
    mem_group_t *group = node->mem_group;

    if (group != NULL)
    {
        // slot index of the object within the group
        index = ((uint8_t *)object - group->object) / cache->object_size;
        // reject out-of-range indices
        if (!(index >= 0 && index < group->map.length * BITS_PER_BYTE))
        {
            Panic("[mem] map index %d above range!\n", index);
        }
        // clear the bitmap slot
        bitmap_set(&group->map, 0, index);
        // update group accounting
        group->free_count++;
        group->using_count--;
        if (!group->using_count)
        {
            // group fully free: park it on the free list
            list_del_init(&group->list);
            list_add_tail(&group->list, &cache->free_groups);
        }
        else if (group->using_count == cache->object_count - 1)
        {
            // the group WAS full before this free: move it back to the
            // partial list so it can serve allocations again.
            // (The old test compared against object_count, which can never
            // hold after the decrement, so once-full groups were stranded on
            // the full list until completely freed.)
            list_del_init(&group->list);
            list_add_tail(&group->list, &cache->partial_groups);
        }
    }
}

// kernel mem alloc
void *KMemAlloc(size_t size)
{
    mem_cache_size_t *_cache = NULL;
    void *object = NULL;
    large_mem_object_t *large_object = NULL;

    // object too large
    if (size > MEM_CACHE_SIZE_MAX)
    {
        if (size > MEM_OBJECT_SIZE_MAX)
        {
            KPrint("[mem] request mem size above max %d\n", MEM_OBJECT_SIZE_MAX);
            return NULL;
        }
#ifdef MEM_DEBUG
        KPrint("[mem] try alloc a large mem object size %d\n", size);
#endif
        large_object = KMemAlloc(sizeof(large_mem_object_t));
        if (large_object != NULL)
        {
            large_object->size = size;
            large_object->vbase = (uint32_t)MemCacheAllocPage(DIV_ROUND_UP(size, PAGE_SIZE)); // alloc assign pages
            if ((void *)large_object->vbase != NULL)
            {
                // add to list
                list_add_after(&large_object->list, &large_object_list);
            }
            // return object vbase
            object = large_object->vbase;
#ifdef MEM_DEBUG
            KPrint("[mem] alloc object at %x size %d\n", object, size);
#endif
        }
    }
    else
    {
        // general mem size alloc
        _cache = &cache_size[0];
        // search fit cache size
        while (_cache->cache_size)
        {
            if (size >= _cache->cache_size)
            {
                _cache++;
                continue;
            }
            break;
        }

        // cache size no valid,return NULL point
        if (_cache->cache_size == 0)
        {
            KPrint("[mem] no valid cache size\n");
            return NULL;
        }

#ifdef MEM_DEBUG
        KPrint("[mem] try alloc object on cache size %d\n", _cache->cache_size);
#endif
        //   alloc object
        object = MemCacheAllocObject(_cache->mem_cache);
    }
#ifdef MEM_DEBUG
    KPrint("[mem] alloc mem address at %x size %d\n", object, size);
#endif
    return object;
}

// kernel mem free
void KMemFree(void *object)
{
    mem_cache_t *cache = NULL;
    mem_node_t *node = NULL;
    large_mem_object_t *large_object = NULL;

    if (object != NULL)
    {
        node = Pybase2Memnode((uint32_t)KERNEL_VBASE2PYBASE(object));
        if (node != NULL)
        {
            cache = node->mem_cache;
            // node in cache
            if (cache != NULL)
            {
                // KPrint("[mem] free object on cache size %d group object %x\n", cache->object_size, node->mem_group->object);
                MemCacheFreeObject(cache, object);
                return;
            }
        }

        // large mem object
        list_traversal_all_owner_to_next(large_object, &large_object_list, list)
        {
            if (large_object->vbase == (uint32_t)object)
            {
                list_del(&large_object->list);
                MemCacheFreePage(large_object->vbase, large_object->size);
                KMemFree(large_object);
                return;
            }
        }
    }
}

void MemCachesInit()
{
    if (MemCacheMake() < 0)
    {
        KPrint("[mem_cache] make cache failed!\n");
        return;
    }
    KPrint("[mem cache] make cache finished!\n");

    // MemCacheTest();
}

// smoke test: alloc/free repeatedly on the smallest size class
void MemCacheTest()
{
    int i;
    // KMemAlloc returns void*; the old uint32_t local converted pointer to
    // integer implicitly and passed an integer back to KMemFree(void *)
    void *base;

    for (i = 0; i < cache_size[0].mem_cache->object_count * 5; i++)
    {
        base = KMemAlloc(cache_size[0].mem_cache->object_size);
        if (!base)
            Panic("[test] alloc error");
        KMemFree(base);
        // second allocation is intentionally kept live so the cache is forced
        // to grow across multiple groups
        // NOTE(review): these objects are never freed — confirm this leak is
        // acceptable for a test routine
        base = KMemAlloc(cache_size[0].mem_cache->object_size);
        if (!base)
            Panic("[test] alloc error!\n");
    }
}